{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77gENRVX40S7"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "d8jyt37T42Vf"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hRTa3Ee15WsJ"
      },
      "source": [
        "# Transfer Learning Using Pretrained ConvNets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dQHMcypT3vDT"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r1/tutorials/images/transfer_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r1/tutorials/images/transfer_learning.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dQHMcypT3vDT"
      },
      "source": [
        "> Note: This is an archived TF1 notebook. These are configured\n",
        "to run in TF2's \n",
        "[compatbility mode](https://www.tensorflow.org/guide/migrate)\n",
        "but will run in TF1 as well. To use TF1 in Colab, use the\n",
        "[%tensorflow_version 1.x](https://colab.research.google.com/notebooks/tensorflow_version.ipynb)\n",
        "magic."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2X4KyhORdSeO"
      },
      "source": [
        "In this tutorial we will discuss how to classify cats vs dogs images by using transfer learning from a pre-trained network. This will allows us to get higher accuracies than we saw by training our network from scratch.\n",
        "\n",
        "A **pre-trained model** is a saved network that was previously trained on a large dataset, typically on a large-scale image-classification task. We can either use the pretrained model as it is or transfer learning using the pretrained ConvNets. The intuition behind **transfer learning** is that if this model trained on a large and general enough dataset, this model will effectively serve as a generic model of the visual world. We can leverage these learned feature maps without having to train a large model on a large dataset by using these models as the basis of our own model specific to our task. There are 2 scenarios of transfer learning using a pretrained model:\n",
        "\n",
        "1. **Feature Extraction** - use the representations of learned by a previous network to extract meaningful features from new samples. We simply add a new classifier, which will be trained from scratch, on top of the pretrained model so that we can repurpose the feature maps learned previously for our dataset. **Do we use the entire pretrained model or just the convolutional base?** - We use the feature extraction portion of these pretrained ConvNets (convolutional base) since they are likely to be generic features and learned concepts over a picture. However, the classification part of the pretrained model is often specific to original classification task, and subsequently specific to the set of classes on which the model was trained.\n",
        "2. **Fine-Tuning** - unfreezing a few of the top layers of a frozen model base used for feature extraction, and jointly training both the newly added classifier layers as well as the last layers of the frozen model. This allows us to \"fine tune\" the higher order feature representations in addition to our final classifier in order to make them more relevant for the specific task involved.\n",
        "\n",
        "**We will follow the general machine learning workflow:**\n",
        "\n",
        "1. Examine and understand data\n",
        "2. Build an input pipeline - using Keras ImageDataGenerator as we did in the image classification tutorial\n",
        "3. Compose our model\n",
        "    * Load in our pretrained model (and pretrained weights)\n",
        "    * Stack our classification layers on top\n",
        "4. Train our model\n",
        "5. Evaluate model\n",
        "\n",
        "We will see an example of using the pre-trained ConvNet as the feature extraction and then fine-tune to train the last few layers of the base model.\n",
        "\n",
        "**Audience:** This post is geared towards beginners with some Keras API and ML background. To get the most out of this post, you should have some basic ML background, know what CNNs are, and be familiar with the Keras Sequential API.\n",
        "\n",
        "**Time Estimated**: 30 minutes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iBMcobPHdD8O"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import tensorflow.compat.v1 as tf\n",
        "\n",
        "from tensorflow import keras\n",
        "print(\"TensorFlow version is \", tf.__version__)\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.image as mpimg"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v77rlkCKW0IJ"
      },
      "source": [
        "## Data preprocessing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aXzwKdouXf1h"
      },
      "source": [
        "### Download data - cats_and_dogs_filtered.zip\n",
        "We will download a filtered version of Kaggle's [Dogs vs Cats](https://www.kaggle.com/c/dogs-vs-cats/data) dataset. Then store the downloaded zip file to the \"/tmp/\" directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nRnO59Kr6enO"
      },
      "outputs": [],
      "source": [
        "zip_file = tf.keras.utils.get_file(origin=\"https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip\",\n",
        "                                   fname=\"cats_and_dogs_filtered.zip\", extract=True)\n",
        "base_dir, _ = os.path.splitext(zip_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9_6h-c5EXN91"
      },
      "source": [
        "### Prepare training and validation cats and dogs datasets\n",
        "Create the training and validation directories for cats datasets and dog datasets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RWcldM4TXLen"
      },
      "outputs": [],
      "source": [
        "train_dir = os.path.join(base_dir, 'train')\n",
        "validation_dir = os.path.join(base_dir, 'validation')\n",
        "\n",
        "# Directory with our training cat pictures\n",
        "train_cats_dir = os.path.join(train_dir, 'cats')\n",
        "print ('Total training cat images:', len(os.listdir(train_cats_dir)))\n",
        "\n",
        "# Directory with our training dog pictures\n",
        "train_dogs_dir = os.path.join(train_dir, 'dogs')\n",
        "print ('Total training dog images:', len(os.listdir(train_dogs_dir)))\n",
        "\n",
        "# Directory with our validation cat pictures\n",
        "validation_cats_dir = os.path.join(validation_dir, 'cats')\n",
        "print ('Total validation cat images:', len(os.listdir(validation_cats_dir)))\n",
        "\n",
        "# Directory with our validation dog pictures\n",
        "validation_dogs_dir = os.path.join(validation_dir, 'dogs')\n",
        "print ('Total validation dog images:', len(os.listdir(validation_dogs_dir)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wvidPx6jeFzf"
      },
      "source": [
        "### Create Image Data Generator with Image Augmentation\n",
        "\n",
        "We will use ImageDataGenerator to rescale the images.\n",
        "\n",
        "To create the train generator, specify where the train dataset directory, image size, batch size and binary classification mode.\n",
        "\n",
        "The validation generator is created the same way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y3PM6GVHcC31"
      },
      "outputs": [],
      "source": [
        "image_size = 160 # All images will be resized to 160x160\n",
        "batch_size = 32\n",
        "\n",
        "# Rescale all images by 1./255 and apply image augmentation\n",
        "train_datagen = keras.preprocessing.image.ImageDataGenerator(\n",
        "                rescale=1./255)\n",
        "\n",
        "validation_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)\n",
        "\n",
        "# Flow training images in batches of 20 using train_datagen generator\n",
        "train_generator = train_datagen.flow_from_directory(\n",
        "                train_dir,  # Source directory for the training images\n",
        "                target_size=(image_size, image_size),\n",
        "                batch_size=batch_size,\n",
        "                # Since we use binary_crossentropy loss, we need binary labels\n",
        "                class_mode='binary')\n",
        "\n",
        "# Flow validation images in batches of 20 using test_datagen generator\n",
        "validation_generator = validation_datagen.flow_from_directory(\n",
        "                validation_dir, # Source directory for the validation images\n",
        "                target_size=(image_size, image_size),\n",
        "                batch_size=batch_size,\n",
        "                class_mode='binary')"
      ]
    },
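    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `rescale=1./255` argument multiplies every pixel value by 1/255, mapping 8-bit intensities from the range 0-255 into 0-1. As a minimal sketch of that arithmetic, the cell below applies the same scaling to a made-up NumPy array; `fake_batch` is a hypothetical stand-in and is not part of the tutorial's pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical stand-in for a batch of two 160x160 RGB images with 8-bit values.\n",
        "fake_batch = np.random.randint(0, 256, size=(2, 160, 160, 3)).astype('float64')\n",
        "\n",
        "# The same scaling that rescale=1./255 requests from ImageDataGenerator.\n",
        "rescaled = fake_batch * (1. / 255)\n",
        "\n",
        "print('before:', fake_batch.min(), fake_batch.max())\n",
        "print('after: ', rescaled.min(), rescaled.max())"
      ]
    },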
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OkH-kazQecHB"
      },
      "source": [
        "## Create the base model from the pre-trained ConvNets\n",
        "We will create the base model from the **MobileNet V2** model developed at Google, and pre-trained on the ImageNet dataset, a large dataset of 1.4M images and 1000 classes of web images. This is a powerful model. Let's see what the features that it has learned can do for our cat vs. dog problem.\n",
        "\n",
        "First, we need to pick which intermediate layer of MobileNet V2 we will use for feature extraction. A common practice is to use the output of the very last layer before the flatten operation, the so-called \"bottleneck layer\". The reasoning here is that the following fully-connected layers will be too specialized to the task the network was trained on, and thus the features learned by these layers won't be very useful for a new task. The bottleneck features, however, retain much generality.\n",
        "\n",
        "Let's instantiate an MobileNet V2 model pre-loaded with weights trained on ImageNet. By specifying the **include_top=False** argument, we load a network that doesn't include the classification layers at the top, which is ideal for feature extraction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "19IQ2gqneqmS"
      },
      "outputs": [],
      "source": [
        "IMG_SHAPE = (image_size, image_size, 3)\n",
        "\n",
        "# Create the base model from the pre-trained model MobileNet V2\n",
        "base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,\n",
        "                                               include_top=False,\n",
        "                                               weights='imagenet')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rlx56nQtfe8Y"
      },
      "source": [
        "## Feature extraction\n",
        "We will freeze the convolutional base created from the previous step and use that as a feature extractor, add a classifier on top of it and train the top-level classifier."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CnMLieHBCwil"
      },
      "source": [
        "### Freeze the convolutional base\n",
        "It's important to freeze the convolutional based before we compile and train the model. By freezing (or setting `layer.trainable = False`), we prevent the weights in these layers from being updated during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OTCJH4bphOeo"
      },
      "outputs": [],
      "source": [
        "base_model.trainable = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KpbzSmPkDa-N"
      },
      "outputs": [],
      "source": [
        "# Let's take a look at the base model architecture\n",
        "base_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wdMRM8YModbk"
      },
      "source": [
        "#### Add a classification head"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0iqnBeZrfoIc"
      },
      "source": [
        "Now let's add a few layers on top of the base model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eApvroIyn1K0"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "  base_model,\n",
        "  keras.layers.GlobalAveragePooling2D(),\n",
        "  keras.layers.Dense(1, activation='sigmoid')\n",
        "])"
      ]
    },
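    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`GlobalAveragePooling2D` collapses each feature map to a single number by averaging over its spatial positions, so the dense layer receives one vector per image. The sketch below reproduces that averaging in NumPy on random data. The 5x5x1280 feature shape is an assumption for a 160x160 input; check the last row of `base_model.summary()` for the actual shape."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical feature maps: 4 images, 5x5 spatial grid, 1280 channels (assumed shape).\n",
        "example_features = np.random.rand(4, 5, 5, 1280)\n",
        "\n",
        "# Average over the height and width axes, keeping the batch and channel axes.\n",
        "example_pooled = example_features.mean(axis=(1, 2))\n",
        "\n",
        "print(example_pooled.shape)  # (4, 1280)"
      ]
    },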
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g0ylJXE_kRLi"
      },
      "source": [
        "### Compile the model\n",
        "\n",
        "You must compile the model before training it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RpR8HdyMhukJ"
      },
      "outputs": [],
      "source": [
        "model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),\n",
        "              loss='binary_crossentropy',\n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I8ARiyMFsgbH"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lxOcmVr0ydFZ"
      },
      "source": [
        "These 1.2K trainable parameters are divided among 2 TensorFlow `Variable` objects, the weights and biases of the two dense layers:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "krvBumovycVA"
      },
      "outputs": [],
      "source": [
        "len(model.trainable_variables)"
      ]
    },
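    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of the trainable parameter count, a dense layer holds one weight per input feature per unit, plus one bias per unit. The arithmetic below assumes the pooled feature vector has 1280 entries, which should match the output shape shown in `model.summary()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Assumed size of the pooled feature vector (see model.summary()).\n",
        "example_num_features = 1280\n",
        "# The classification head is Dense(1).\n",
        "example_num_units = 1\n",
        "\n",
        "kernel_params = example_num_features * example_num_units  # one weight per feature per unit\n",
        "bias_params = example_num_units                           # one bias per unit\n",
        "\n",
        "print(kernel_params + bias_params)  # 1281"
      ]
    },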
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RxvgOYTDSWTx"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "After training for 10 epochs, we are able to get ~94% accuracy.\n",
        "\n",
        "If you have more time, train it to convergence (50 epochs, ~96% accuracy)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Om4O3EESkab1"
      },
      "outputs": [],
      "source": [
        "epochs = 10\n",
        "steps_per_epoch = train_generator.n // batch_size\n",
        "validation_steps = validation_generator.n // batch_size\n",
        "\n",
        "history = model.fit_generator(train_generator,\n",
        "                              steps_per_epoch = steps_per_epoch,\n",
        "                              epochs=epochs,\n",
        "                              workers=4,\n",
        "                              validation_data=validation_generator,\n",
        "                              validation_steps=validation_steps)"
      ]
    },
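    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`steps_per_epoch` and `validation_steps` are computed with floor division, so any final partial batch is not counted. The cell below works through the arithmetic with assumed dataset sizes of 2000 training and 1000 validation images; use the counts printed earlier in the notebook for the real values. The `example_` names are hypothetical and chosen so that the notebook's own variables are left untouched."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Assumed image counts; the real ones are train_generator.n and validation_generator.n.\n",
        "example_num_train = 2000\n",
        "example_num_validation = 1000\n",
        "example_batch_size = 32\n",
        "\n",
        "# Floor division drops the remainder, so a partial last batch is not a step.\n",
        "example_steps_per_epoch = example_num_train // example_batch_size\n",
        "example_validation_steps = example_num_validation // example_batch_size\n",
        "\n",
        "print(example_steps_per_epoch)   # 62\n",
        "print(example_validation_steps)  # 31"
      ]
    },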
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hd94CKImf8vi"
      },
      "source": [
        "### Learning curves\n",
        "\n",
        "Let's take a look at the learning curves of the training and validation accuracy / loss, when using the MobileNet V2 base model as a fixed feature extractor."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l7HOsQTPNgO9"
      },
      "source": [
        "If you train to convergence (`epochs=50`) the resulting graph should look like this:\n",
        "\n",
        "![Before fine tuning, the model reaches 96% accuracy](https://github.com/tensorflow/docs/blob/master/site/en/r1/tutorials/images/images/before_fine_tuning.png?raw=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "53OTCh3jnbwV"
      },
      "outputs": [],
      "source": [
        "acc = history.history['accuracy']\n",
        "val_acc = history.history['val_accuracy']\n",
        "\n",
        "loss = history.history['loss']\n",
        "val_loss = history.history['val_loss']\n",
        "\n",
        "plt.figure(figsize=(8, 8))\n",
        "plt.subplot(2, 1, 1)\n",
        "plt.plot(acc, label='Training Accuracy')\n",
        "plt.plot(val_acc, label='Validation Accuracy')\n",
        "plt.legend(loc='lower right')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.ylim([min(plt.ylim()),1])\n",
        "plt.title('Training and Validation Accuracy')\n",
        "\n",
        "plt.subplot(2, 1, 2)\n",
        "plt.plot(loss, label='Training Loss')\n",
        "plt.plot(val_loss, label='Validation Loss')\n",
        "plt.legend(loc='upper right')\n",
        "plt.ylabel('Cross Entropy')\n",
        "plt.ylim([0,max(plt.ylim())])\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CqwV-CRdS6Nv"
      },
      "source": [
        "## Fine tuning\n",
        "In our feature extraction experiment, we were only training a few layers on top of an MobileNet V2 base model. The weights of the pre-trained network were **not** updated during training. One way to increase performance even further is to \"fine-tune\" the weights of the top layers of the pre-trained model alongside the training of the top-level classifier. The training process will force the weights to be tuned from generic features maps to features associated specifically to our dataset.\n",
        "\n",
        "Note: this should only be attempted after you have trained the top-level classifier with the pre-trained model set to non-trainable. If you add a randomly initialized classifier on top of a pre-trained model and attempt to train all layers jointly, the magnitude of the gradient updates will be too large (due to the random weights from the classifier) and your pre-trained model will just forget everything it has learned.\n",
        "\n",
        "Additionally, the reasoning behind fine-tuning the top layers of the pre-trained model rather than all layers of the pre-trained model is the following: in a ConvNet, the higher up a layer is, the more specialized it is. The first few layers in a ConvNet learned very simple and generic features, which generalize to almost all types of images. But as you go higher up, the features are increasingly more specific to the dataset that the model was trained on. The goal of fine-tuning is to adapt these specialized features to work with the new dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CPXnzUK0QonF"
      },
      "source": [
        "### Un-freeze the top layers of the model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rfxv_ifotQak"
      },
      "source": [
        "All we need to do is unfreeze the `base_model`, and set the bottom layers be un-trainable. Then, recompile the model (necessary for these changes to take effect), and resume training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4nzcagVitLQm"
      },
      "outputs": [],
      "source": [
        "base_model.trainable = True"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-4HgVAacRs5v"
      },
      "outputs": [],
      "source": [
        "# Let's take a look to see how many layers are in the base model\n",
        "print(\"Number of layers in the base model: \", len(base_model.layers))\n",
        "\n",
        "# Fine tune from this layer onwards\n",
        "fine_tune_at = 100\n",
        "\n",
        "# Freeze all the layers before the `fine_tune_at` layer\n",
        "for layer in base_model.layers[:fine_tune_at]:\n",
        "  layer.trainable =  False"
      ]
    },
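    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The slice `base_model.layers[:fine_tune_at]` selects every layer before index `fine_tune_at`, so only the layers from that index onwards stay trainable. The sketch below shows the same slicing on stand-in objects; `FakeLayer` and the layer count of 155 are hypothetical, used only to keep the example independent of the real model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class FakeLayer:\n",
        "  # Hypothetical stand-in that exposes only a `trainable` flag.\n",
        "  def __init__(self):\n",
        "    self.trainable = True\n",
        "\n",
        "# Assumed layer count, for illustration only.\n",
        "example_layers = [FakeLayer() for _ in range(155)]\n",
        "example_fine_tune_at = 100\n",
        "\n",
        "# Freeze everything before the cut-off index.\n",
        "for fake_layer in example_layers[:example_fine_tune_at]:\n",
        "  fake_layer.trainable = False\n",
        "\n",
        "# Layers from index 100 onwards remain trainable.\n",
        "print(sum(fake_layer.trainable for fake_layer in example_layers))  # 55"
      ]
    },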
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Uk1dgsxT0IS"
      },
      "source": [
        "### Compile the model\n",
        "\n",
        "Compile the model using a much-lower training rate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NtUnaz0WUDva"
      },
      "outputs": [],
      "source": [
        "model.compile(optimizer = tf.keras.optimizers.RMSprop(lr=2e-5),\n",
        "              loss='binary_crossentropy',\n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WwBWy7J2kZvA"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bNXelbMQtonr"
      },
      "outputs": [],
      "source": [
        "len(model.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4G5O4jd6TuAG"
      },
      "source": [
        "### Continue Train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0foWUN-yDLo_"
      },
      "source": [
        "If you trained to convergence earlier, this will get you a few percent more accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ECQLkAsFTlun"
      },
      "outputs": [],
      "source": [
        "history_fine = model.fit_generator(train_generator,\n",
        "                                   steps_per_epoch = steps_per_epoch,\n",
        "                                   epochs=epochs,\n",
        "                                   workers=4,\n",
        "                                   validation_data=validation_generator,\n",
        "                                   validation_steps=validation_steps)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TfXEmsxQf6eP"
      },
      "source": [
        "### Learning curves\n",
        "\n",
        "Let's take a look at the learning curves of the training and validation accuracy / loss, when fine tuning the last few layers of the MobileNet V2 base model, as well as the classifier on top of it. Note the validation loss much higher than the training loss which means there maybe some overfitting.\n",
        "\n",
        "**Note**: the training dataset is fairly small, and is similar to the original datasets that MobileNet V2 was trained on, so fine-tuning may result in overfitting.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DNtfNZKlInGT"
      },
      "source": [
        "If you train to convergence (`epochs=50`) the resulting graph should look like this:\n",
        "\n",
        "![After fine tuning the model nearly reaches 98% accuracy](https://github.com/tensorflow/docs/blob/master/site/en/r1/tutorials/images/images/fine_tuning.png?raw=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PpA8PlpQKygw"
      },
      "outputs": [],
      "source": [
        "acc += history_fine.history['accuracy']\n",
        "val_acc += history_fine.history['val_accuracy']\n",
        "\n",
        "loss += history_fine.history['loss']\n",
        "val_loss += history_fine.history['val_loss']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "chW103JUItdk"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(8, 8))\n",
        "plt.subplot(2, 1, 1)\n",
        "plt.plot(acc, label='Training Accuracy')\n",
        "plt.plot(val_acc, label='Validation Accuracy')\n",
        "plt.ylim([0.9, 1])\n",
        "plt.plot([epochs-1,epochs-1], plt.ylim(), label='Start Fine Tuning')\n",
        "plt.legend(loc='lower right')\n",
        "plt.title('Training and Validation Accuracy')\n",
        "\n",
        "plt.subplot(2, 1, 2)\n",
        "plt.plot(loss, label='Training Loss')\n",
        "plt.plot(val_loss, label='Validation Loss')\n",
        "plt.ylim([0, 0.2])\n",
        "plt.plot([epochs-1,epochs-1], plt.ylim(), label='Start Fine Tuning')\n",
        "plt.legend(loc='upper right')\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tTrNmiH5vF4A"
      },
      "outputs": [],
      "source": [
        "acc"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZzvAvWZevLWd"
      },
      "outputs": [],
      "source": [
        "loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_TZTwG7nhm0C"
      },
      "source": [
        "# Key takeaways\n",
        "In summary here is what we covered in this tutorial on how to do transfer learning using a pre-trained model to improve accuracy:\n",
        "* Using a pre-trained model for **feature extraction** - when working with a small dataset, it is common to leverage the features learned by a model trained on a larger dataset in the same domain. This is done by instantiating the pre-trained model and adding a fully connected classifier on top. The pre-trained model is \"frozen\" and only the weights of the classifier are updated during training.\n",
        "In this case, the convolutional base extracts all the features associated with each image and we train a classifier that determines, given these set of features to which class it belongs.\n",
        "* **Fine-tuning** a pre-trained model - to further improve performance, one might want to repurpose the top-level layers of the pre-trained models to the new dataset via fine-tuning.\n",
        "In this case, we tune our weights such that we learn highly specified and high level features specific to our dataset. This only make sense when the training dataset is large and very similar to the original dataset that the pre-trained model was trained on.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "transfer_learning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
